from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand

# Keyword arguments that `_conv` forwards to fastai's `conv_layer` on every call:
# `leaky=0.2` and spectral normalisation as the norm type.
_conv_args = {"leaky": 0.2, "norm_type": NormType.Spectral}

def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
    "Build a fastai `conv_layer` from `ni` to `nf` channels with the module-wide `_conv_args` applied."
    # Both mappings are unpacked straight into the call (not merged beforehand), so a
    # caller repeating a key that is already in `_conv_args` gets a TypeError rather
    # than silently overriding the shared default.
    layer = conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
    return layer

def custom_gan_critic(n_channels:int=3, nf:int=256, n_blocks:int=3, p:float=0.15):
    """Critic to train a `GAN`.

    Stacks `_conv` layers (fastai `conv_layer` with this module's `_conv_args`)
    and 2d dropout into an `nn.Sequential`, in this order:

    * stem: a `ks=4, stride=2` conv from `n_channels` to `nf`, then dropout at `p/2`;
    * `n_blocks` stages, each made of a `ks=3, stride=1` conv at constant width,
      dropout at `p`, and a `ks=4, stride=2` conv that doubles the width
      (`self_attention=True` is requested on the first stage only);
    * head: a `ks=3, stride=1` conv at constant width, then a `ks=4` conv with
      `padding=0`, `bias=False` and `use_activ=False` down to a single channel,
      followed by `Flatten()`.

    Args:
        n_channels: input channels of the first conv.
        nf: output channels of the stem; doubled after every stage.
        n_blocks: number of stages; with 0 the model is just stem + head.
        p: dropout probability handed to `nn.Dropout2d` in each stage (the stem
            uses `p/2`). It is a probability, so it is annotated `float`.

    Returns:
        The assembled critic as an `nn.Sequential`.
    """
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            # Only the first stage asks `conv_layer` for self-attention.
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        # Track the width produced by the stage's last conv for the next iteration/head.
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        # Final projection to one channel: no bias, no padding, no activation.
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten()]
    return nn.Sequential(*layers)

def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
    "Wrap a freshly built `custom_gan_critic` (stem width `nf`) in a `Learner` over `data`."
    # NOTE: the default `loss_critic` is constructed once, when this function is
    # defined, so every call that omits the argument shares that same loss object.
    critic = custom_gan_critic(nf=nf)
    learner = Learner(data, critic, metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)
    return learner
